# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys
import json
from tqdm import tqdm

# Make this script's directory and the repository root importable so the
# project-local `ppocr` and `tools` packages imported below resolve no
# matter which working directory the script is launched from.  This must
# run BEFORE the `ppocr`/`tools` imports further down.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

# Let Paddle grow its GPU memory pool on demand instead of pre-allocating
# a large fraction of device memory up front.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import paddle
from paddle.jit import to_static

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
from ppocr.utils.visual import draw_rectangle
from tools.infer.utility import draw_boxes
import tools.program as program
import cv2


@paddle.no_grad()
def main(config, device, logger, vdl_writer):
    """Run batched table-structure inference and dump results as JSON lines.

    Builds the model and post-processor from ``config``, runs every image
    under ``config['Global']['infer_img']`` through the network in batches,
    and writes one JSON record per image (``filename`` + predicted ``html``
    structure) to ``<save_res_path>/SLANet.json``.

    Args:
        config: parsed experiment configuration dict (from program.preprocess).
        device: paddle device handle (unused here; kept for the common
            tools-entry signature).
        logger: project logger.
        vdl_writer: VisualDL writer (unused here; kept for the common
            tools-entry signature).
    """
    global_config = config['Global']

    # Build the post-processor first: its character dict (when present)
    # determines the recognition head's output channel count.
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)
    if hasattr(post_process_class, 'character'):
        config['Architecture']["Head"]['out_channels'] = len(
            getattr(post_process_class, 'character'))

    model = build_model(config['Architecture'])
    load_model(config, model)

    # Build the data pipeline: drop label-encoding ops (inference has no
    # labels) and restrict KeepKeys to the fields the model consumes.
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Encode' in op_name:
            continue
        if op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    os.makedirs(save_res_path, exist_ok=True)

    model.eval()
    bs = 96
    # FIX: enumerate the image files once.  The original re-ran
    # get_image_file_list() both in the range expression and inside the
    # loop for every batch, re-scanning the input directory each time.
    image_files = get_image_file_list(config['Global']['infer_img'])
    with open(
            os.path.join(save_res_path, 'SLANet.json'), mode='w',
            encoding='utf-8') as f_w:
        for i in tqdm(range(0, len(image_files), bs)):
            images = []
            shape_list = []
            file_list = []
            for file in image_files[i:i + bs]:
                file_list.append(file)
                with open(file, 'rb') as f:
                    img = f.read()
                data = {'image': img}
                batch = transform(data, ops)
                images.append(batch[0])
                shape_list.append(batch[1])
            # Guard against an empty tail batch (e.g. no images found).
            if not images:
                continue
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, [shape_list])
            for file, structure_str_list in zip(
                    file_list, post_result['structure_batch_list']):
                structure_str_list = structure_str_list[0]
                res_dict = {
                    # NOTE(review): the 'testA/' prefix looks like a fixed
                    # challenge-submission convention — confirm before reuse.
                    # basename() replaces the non-portable split('/')[-1].
                    'filename': 'testA/' + os.path.basename(file),
                    'html': structure_str_list,
                }
                f_w.write(json.dumps(res_dict) + '\n')
        logger.info("success!")

if __name__ == '__main__':
    # Entry point: parse CLI/config into (config, device, logger,
    # vdl_writer) and hand them straight to the inference routine.
    runtime = program.preprocess()
    main(*runtime)
